{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c4a4cec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%run -i \"../util/util_simple_classifier.ipynb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9ab0c9ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk import word_tokenize\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "05963dfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load cleaned data from disk\n",
    "train_df = pd.read_json(\"../data/rotten_tomatoes_train.json\")\n",
    "test_df = pd.read_json(\"../data/rotten_tomatoes_test.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e593df8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the positive and negative word lists\n",
    "# Only from training data\n",
    "positive_train_words = train_df[train_df[\"label\"] == 1].tokenized_text.sum()\n",
    "negative_train_words = train_df[train_df[\"label\"] == 0].tokenized_text.sum()\n",
    "word_intersection = set(positive_train_words) & set(negative_train_words)\n",
    "positive_filtered = list(set(positive_train_words) - word_intersection)\n",
    "negative_filtered = list(set(negative_train_words) - word_intersection)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "37cb3157",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_vectorizers(word_lists):\n",
    "    vectorizers = []\n",
    "    for word_list in word_lists:\n",
    "        vectorizer = CountVectorizer(vocabulary=word_list)\n",
    "        vectorizers.append(vectorizer)\n",
    "    return vectorizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9c7cccb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorizers = create_vectorizers([negative_filtered, positive_filtered])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "79cb2dce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def vectorize(text_list, vectorizers):\n",
    "    text = \" \".join(text_list)\n",
    "    scores = []\n",
    "    for vectorizer in vectorizers:\n",
    "        output = vectorizer.transform([text])\n",
    "        output_sum = sum(output.todense().tolist()[0])\n",
    "        scores.append(output_sum)\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9f8ac8f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify(score_list):\n",
    "    return max(enumerate(score_list),key=lambda x: x[1])[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4d58a705",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                   text  label lang  \\\n",
      "0     the rock is destined to be the 21st century's ...      1   en   \n",
      "1     the gorgeously elaborate continuation of \" the...      1   en   \n",
      "2                        effective but too-tepid biopic      1   en   \n",
      "3     if you sometimes like to go to the movies to h...      1   en   \n",
      "4     emerges as something rare , an issue movie tha...      1   en   \n",
      "...                                                 ...    ...  ...   \n",
      "8525  any enjoyment will be hinge from a personal th...      0   en   \n",
      "8526  if legendary shlockmeister ed wood had ever ma...      0   en   \n",
      "8527  hardly a nuanced portrait of a young woman's b...      0   en   \n",
      "8528    interminably bleak , to say nothing of boring .      0   en   \n",
      "8529  things really get weird , though not particula...      0   en   \n",
      "\n",
      "                                         tokenized_text  prediction  \n",
      "0     [rock, destined, 21st, century, new, conan, go...           1  \n",
      "1     [gorgeously, elaborate, continuation, lord, ri...           1  \n",
      "2                        [effective, too-tepid, biopic]           0  \n",
      "3     [sometimes, like, go, movies, fun, wasabi, goo...           0  \n",
      "4     [emerges, something, rare, issue, movie, hones...           1  \n",
      "...                                                 ...         ...  \n",
      "8525  [enjoyment, hinge, personal, threshold, watchi...           0  \n",
      "8526  [legendary, shlockmeister, ed, wood, ever, mad...           0  \n",
      "8527  [hardly, nuanced, portrait, young, woman, brea...           0  \n",
      "8528        [interminably, bleak, say, nothing, boring]           0  \n",
      "8529  [things, really, get, weird, though, particula...           0  \n",
      "\n",
      "[8352 rows x 5 columns]\n"
     ]
    }
   ],
   "source": [
    "train_df[\"prediction\"] = train_df[\"tokenized_text\"].apply(lambda x: classify(vectorize(x, vectorizers)))\n",
    "print(train_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9ef532b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.79      0.99      0.88      4185\n",
      "           1       0.99      0.74      0.85      4167\n",
      "\n",
      "    accuracy                           0.87      8352\n",
      "   macro avg       0.89      0.87      0.86      8352\n",
      "weighted avg       0.89      0.87      0.86      8352\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Training metrics\n",
    "print(classification_report(train_df['label'], train_df['prediction']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c0a21465",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.59      0.81      0.68       523\n",
      "           1       0.69      0.43      0.53       522\n",
      "\n",
      "    accuracy                           0.62      1045\n",
      "   macro avg       0.64      0.62      0.61      1045\n",
      "weighted avg       0.64      0.62      0.61      1045\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Test metrics\n",
    "test_df[\"prediction\"] = test_df[\"tokenized_text\"].apply(lambda x: classify(vectorize(x, vectorizers)))\n",
    "print(classification_report(test_df['label'], test_df['prediction']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "26c83f8f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.79      0.99      0.88      4265\n",
      "           1       0.99      0.74      0.84      4265\n",
      "\n",
      "    accuracy                           0.86      8530\n",
      "   macro avg       0.89      0.86      0.86      8530\n",
      "weighted avg       0.89      0.86      0.86      8530\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.59      0.81      0.68       533\n",
      "           1       0.70      0.44      0.54       533\n",
      "\n",
      "    accuracy                           0.62      1066\n",
      "   macro avg       0.65      0.62      0.61      1066\n",
      "weighted avg       0.65      0.62      0.61      1066\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Compare to performance on uncleaned data\n",
    "(train_df, test_df) = load_train_test_dataset_pd(\"train\", \"test\")\n",
    "train_df[\"text\"] = train_df[\"text\"].apply(word_tokenize)\n",
    "test_df[\"text\"] = test_df[\"text\"].apply(word_tokenize)\n",
    "positive_train_words = train_df[train_df[\"label\"] == 1].text.sum()\n",
    "negative_train_words = train_df[train_df[\"label\"] == 0].text.sum()\n",
    "word_intersection = set(positive_train_words) & set(negative_train_words)\n",
    "positive_filtered = list(set(positive_train_words) - word_intersection)\n",
    "negative_filtered = list(set(negative_train_words) - word_intersection)\n",
    "vectorizers = create_vectorizers([negative_filtered, positive_filtered])\n",
    "train_df[\"prediction\"] = train_df[\"text\"].apply(lambda x: classify(vectorize(x, vectorizers)))\n",
    "test_df[\"prediction\"] = test_df[\"text\"].apply(lambda x: classify(vectorize(x, vectorizers)))\n",
    "print(classification_report(train_df['label'], train_df['prediction']))\n",
    "print(classification_report(test_df['label'], test_df['prediction']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b829d2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
